import os
import sys
sys.path.append('./')
import json
import argparse
import torch
import joblib
import numpy as np
from smplx import SMPLX

import config as _C
from vis.animation import render_scene_list, renderables

import os
# Ask PyOpenGL to use its 'osmesa' platform backend.
# NOTE(review): this assignment runs after `vis.animation` has already been
# imported above; if that module (or anything it imports) loads PyOpenGL at
# import time, the setting may come too late to take effect — confirm, and
# move it above the imports if so.
os.environ.update(PYOPENGL_PLATFORM='osmesa')

# JSON file (relative to the working directory) mapping segment names to lists of SMPL-X vertex indices.
VERTS_SEGMENT_IDXS = "data/smplx_part_idxs_bcd.json"
def build_contact_vertex_colors(contact_labels, num_frames, num_verts,
                                segm_verts_dict, sensor_name_mapper,
                                sensor_name_list, base_rgb=0.65,
                                base_alpha=0.9,
                                contact_rgba=(1.0, 0.5, 0.5, 0.75)):
    """Return per-frame RGBA vertex colors highlighting in-contact segments.

    Args:
        contact_labels: 2-D array indexed ``[frame, sensor]``; a non-zero /
            truthy entry marks that sensor as in contact on that frame.
        num_frames: Number of frames in the output (first axis).
        num_verts: Number of mesh vertices (second axis).
        segm_verts_dict: Maps a segment name to a list of vertex indices.
        sensor_name_mapper: Maps a sensor name to a segment name in
            ``segm_verts_dict``.
        sensor_name_list: Ordered sensor names; a sensor's position in this
            list is its column in ``contact_labels``.
        base_rgb: Gray level used for every non-contact vertex.
        base_alpha: Alpha used for every non-contact vertex.
        contact_rgba: RGBA assigned to a segment's vertices on contact frames.

    Returns:
        Float array of shape ``(num_frames, num_verts, 4)``.

    Raises:
        ValueError: if a sensor in ``sensor_name_mapper`` is missing from
            ``sensor_name_list``.
        KeyError: if a mapped segment is missing from ``segm_verts_dict``.
    """
    # Force a boolean mask: np.ix_ treats bool arrays as masks but integer
    # arrays as *indices*, so 0/1 integer labels would otherwise select
    # frames 0 and 1 instead of the frames that are in contact.
    # NOTE(review): float labels are treated as "contact" when non-zero.
    contact_mask = np.asarray(contact_labels).astype(bool)

    colors = np.empty((num_frames, num_verts, 4))
    colors[..., :3] = base_rgb
    colors[..., 3] = base_alpha
    contact_rgba = np.asarray(contact_rgba, dtype=float)

    for sensor_name, segment_name in sensor_name_mapper.items():
        column = sensor_name_list.index(sensor_name)
        vert_idxs = segm_verts_dict[segment_name]
        contact_frames = np.flatnonzero(contact_mask[:, column])
        colors[np.ix_(contact_frames, vert_idxs)] = contact_rgba
    return colors


def main():
    """Load one sequence's annotation and render it with contact highlights."""
    parser = argparse.ArgumentParser()
    parser.add_argument('-s', '--sequence', default='',
                        help="sequence name (defaults to config.SEQUENCE_NAME)")
    parser.add_argument('-g', '--gender', default='neutral',
                        help="body-model gender used when the annotation has none")
    args = parser.parse_args()

    if args.sequence != '':
        _C.SEQUENCE_NAME = args.sequence
    # Without -s, fall back to the configured sequence instead of building
    # the broken path "<DATA_BASE_DIR>/_annot.pkl".
    sequence = args.sequence or getattr(_C, "SEQUENCE_NAME", "")
    if not sequence:
        parser.error("no sequence given: pass -s/--sequence or set config.SEQUENCE_NAME")

    annotation_pth = os.path.join(_C.DATA_BASE_DIR, sequence, f"{sequence}_annot.pkl")
    # joblib.load unpickles the file: only load trusted annotation files.
    annotation = joblib.load(annotation_pth)
    # Prefer the gender stored with the annotation; --gender is the fallback.
    gender = annotation.get("gender", args.gender)

    # aitviewer params
    to_tensor = lambda x: torch.from_numpy(x).float()
    kwargs = {
        'poses_root': to_tensor(annotation['global_orient']),
        'poses_body': to_tensor(annotation['body_pose']),
        'betas': to_tensor(annotation['betas']),
        'trans': to_tensor(annotation['transl']),
    }

    body_model = SMPLX(_C.SMPLX_MODEL_DIR,
                       num_betas=len(annotation['betas'][0]),
                       gender=gender,
    )
    # Vertex count derived from the largest vertex index referenced by a face.
    num_verts = int(body_model.faces.max()) + 1

    with open(VERTS_SEGMENT_IDXS, "r") as f:
        segm_verts_dict = json.load(f)

    colors = build_contact_vertex_colors(
        annotation["contact"],
        annotation["global_orient"].shape[0],
        num_verts,
        segm_verts_dict,
        _C.SENSOR_NAME_MAPPER,
        _C.SENSOR_NAME_LIST,
    )

    scene_list = [
        renderables.addSMPLSequence('smplx', z_up=True, color=[0.65, 0.65, 0.65, 0.8],
                                    bm=body_model, vertex_colors=colors, **kwargs),
    ]
    render_scene_list(scene_list)


if __name__ == "__main__":
    main()